#!/usr/bin/env python3
import json
import os
import re
import sys
import time
from collections import deque
from urllib.parse import urldefrag, urljoin, urlparse

import requests
import urllib3
from colorama import init
from rich import box
from rich.console import Console
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
from rich.table import Table

from argus.utils.util import clean_domain_input, ensure_directory_exists, write_to_file
from argus.config.settings import DEFAULT_TIMEOUT, EXPORT_SETTINGS, RESULTS_DIR

# fetch() sends every request with verify=False; silence the
# InsecureRequestWarning urllib3 would otherwise emit for each one.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# Initialise colorama with autoreset=True.
init(autoreset=True)

# Shared rich console for this module's terminal output.
console = Console()
# Hex accent colour used in the banner's rich markup.
TEAL = "#2EC4B6"

# Lightweight regex-based HTML scanning (no real HTML parser is used).
# PAT_LINK: captures the quoted href value of tags starting with "<a".
# NOTE(review): it also matches other tags whose name begins with "a"
# (e.g. <area>, <abbr>), and requires a quote immediately after "href=".
PAT_LINK  = re.compile(r'<a[^>]+href=[\'"]([^\'"]+)[\'"]', re.I)
# PAT_FORM / PAT_INPUT: group(1) is the raw attribute text inside the tag.
PAT_FORM  = re.compile(r'<form([^>]*)>', re.I)
PAT_INPUT = re.compile(r'<input([^>]*)>', re.I)
# PAT_ATTR: (name, value) pairs for quoted attributes only. Unquoted values
# (type=password) and value-less attributes are not captured, and names keep
# their original case.
PAT_ATTR  = re.compile(r'([a-zA-Z_:-]+)\s*=\s*["\']([^"\']*)["\']')
# Substrings that mark a field as sensitive when found in its lowercased
# name/id (checked in parse_exposures).
SENSITIVE_KEYWORDS = ("user", "login", "email", "password", "token", "apikey", "secret")

def banner():
    """Print the three-line Argus header (rule, title, rule + blank line)."""
    rule = f"[{TEAL}]" + "=" * 44
    header_lines = (
        rule,
        "[cyan]   Argus – Autocomplete Vulnerability Checker",
        rule + "\n",
    )
    for line in header_lines:
        console.print(line)

def fetch(url: str, timeout: int):
    """GET *url* and return ``(status_code, body_text, final_url)``.

    TLS certificate verification is disabled and redirects are followed, so
    ``final_url`` is the URL after any redirects.

    Any request/transport failure or malformed URL is treated as a failed
    fetch and reported as ``(0, "", url)``, which lets the crawler keep going.
    Only ``requests`` errors and ``ValueError`` (bad IDNA hosts, unparsable
    URLs) are caught. ``KeyboardInterrupt``/``SystemExit`` still propagate, so
    a long crawl can be interrupted.
    """
    try:
        r = requests.get(url, timeout=timeout, verify=False, allow_redirects=True)
        return r.status_code, r.text, r.url
    except (requests.RequestException, ValueError):
        return 0, "", url

def _same_host_links(page_url: str, markup: str, netloc: str):
    """Yield absolute, fragment-free http(s) links in *markup* whose host is *netloc*."""
    for m in PAT_LINK.finditer(markup):
        try:
            link = urldefrag(urljoin(page_url, m.group(1)))[0]
            parsed = urlparse(link)
        except ValueError:
            # e.g. an href with an unbalanced IPv6 bracket; skip the link
            # instead of aborting the whole crawl.
            continue
        if parsed.scheme in ("http", "https") and parsed.netloc == netloc:
            yield link


def crawl(domain: str, timeout: int, max_pages: int):
    """Breadth-first crawl of *domain*; return up to *max_pages* pages.

    Each page is a ``(url, status, html)`` tuple as produced by ``fetch``.
    The root is tried over HTTPS first, then plain HTTP. If both fail, an
    empty list is returned.

    * Only ``<a href>`` links on the root's final (post-redirect) host are
      followed. A page that redirects to another host is still recorded, but
      its links are not expanded.
    * Pages are fetched lazily when dequeued, so the crawl issues roughly
      ``max_pages`` requests rather than one per discovered link.
    * URLs are compared without their ``#fragment`` and after redirects, so
      the same document is recorded only once.
    * Failed fetches are recorded with status ``0`` and empty HTML, and count
      toward ``max_pages``.
    """
    root = None
    for scheme in ("https", "http"):
        status, html, final = fetch(f"{scheme}://{domain}", timeout)
        if status:
            root = (status, html, final)
            break
    if root is None:
        return []

    root_url = urldefrag(root[2])[0]
    scope = urlparse(root_url).netloc  # the only host whose links are followed
    prefetched = {root_url: root}      # the root response was already fetched above
    pending = deque([root_url])        # URLs waiting to be fetched, in BFS order
    queued = {root_url}                # every URL ever enqueued (never fetched twice)
    recorded = set()                   # final URLs already appended to pages
    pages = []

    with Progress(
        SpinnerColumn(),
        TextColumn("[white]{task.completed}/{task.total} pages"),
        BarColumn(),
        console=console,
        transient=True
    ) as progress:
        task = progress.add_task("Crawling…", total=max_pages)
        while pending and len(pages) < max_pages:
            url = pending.popleft()
            if url in prefetched:
                status, html, final = prefetched.pop(url)
            else:
                status, html, final = fetch(url, timeout)

            final_key = urldefrag(final)[0]
            if final_key in recorded:
                continue  # a redirect landed on a page we already recorded
            recorded.add(final_key)
            queued.add(final_key)
            pages.append((final, status, html))
            progress.update(task, advance=1)

            # Nothing to parse, or redirected off-site: don't expand its links.
            if not html or urlparse(final).netloc != scope:
                continue
            for link in _same_host_links(final, html, scope):
                if link not in queued:
                    queued.add(link)
                    pending.append(link)

    return pages

# Next <form ...> or </form> tag; marks where the current form's body ends.
_FORM_BOUNDARY = re.compile(r'</?form\b', re.I)

# Input types that carry no user-typed text (hidden values, buttons, toggles,
# file pickers). Whatever their autocomplete setting, there is no typed value
# to remember. Flagging them only adds false positives such as a hidden
# "csrf_token" field or a "login" submit button.
_NON_TEXT_INPUT_TYPES = frozenset(
    ("hidden", "submit", "button", "reset", "image", "checkbox", "radio", "file")
)


def parse_exposures(pages):
    """Extract every ``<input>`` inside a ``<form>`` and flag risky autocomplete.

    Args:
        pages: iterable of ``(page_url, status, html)`` tuples (as returned by
            ``crawl``); pages with empty HTML are skipped.

    Returns:
        A list of ``(page_url, status, form_action_url, field_name,
        autocomplete, input_type, flag)`` tuples, one per input. ``flag`` is
        ``"⚠️"`` when the field looks sensitive (by name or type) and its
        effective autocomplete is neither ``off`` nor ``new-password``,
        otherwise ``""``.
    """
    exposures = []
    for page_url, status, html in pages:
        if not html:
            continue
        for fm in PAT_FORM.finditer(html):
            # HTML attribute names are case-insensitive; normalise them.
            attrs = {k.lower(): v for k, v in PAT_ATTR.findall(fm.group(1))}
            default_ac = attrs.get("autocomplete", "").lower()
            action = attrs.get("action", "")
            full_action = urljoin(page_url, action) if action else page_url

            # Scan only this form's own body: up to the next <form>/</form>
            # tag, or to end of document if the form is never closed. This
            # stops fields of a following form being attributed to this one,
            # and stops long forms being truncated.
            body_start = fm.end()
            boundary = _FORM_BOUNDARY.search(html, body_start)
            body_end = boundary.start() if boundary else len(html)

            for inp in PAT_INPUT.finditer(html, body_start, body_end):
                iattrs = {k.lower(): v for k, v in PAT_ATTR.findall(inp.group(1))}
                name = iattrs.get("name") or iattrs.get("id") or "-"
                typ = iattrs.get("type", "text").lower()
                # A field-level autocomplete overrides the form-level default.
                ac = iattrs.get("autocomplete", default_ac).lower() or "-"
                risky = typ not in _NON_TEXT_INPUT_TYPES and (
                    any(k in name.lower() for k in SENSITIVE_KEYWORDS)
                    or typ in ("password", "email", "token")
                )
                flag = "⚠️" if risky and ac not in ("off", "new-password") else ""
                exposures.append((page_url, status, full_action, name, ac, typ, flag))
    return exposures

def run(target, threads, opts):
    """Run the autocomplete check against *target* and report the results.

    Args:
        target: raw target string; normalised with ``clean_domain_input``.
        threads: accepted for the call signature but not used by this check.
        opts: mapping that may supply ``"timeout"`` (seconds, default
            ``DEFAULT_TIMEOUT``) and ``"max_pages"`` (crawl budget, default
            25). Both are coerced with ``int()``.

    Prints a findings table plus summary panel, or a "none found" message.
    When ``EXPORT_SETTINGS["enable_txt_export"]`` is truthy, it also writes a
    plain-text copy to ``RESULTS_DIR/<domain>/autocomplete_check.txt``.
    """
    banner()
    domain = clean_domain_input(target)
    timeout = int(opts.get("timeout", DEFAULT_TIMEOUT))
    max_pages = int(opts.get("max_pages", 25))

    console.print(
        f"[white]* Target:[bold]{domain}[/bold]  timeout: {timeout}s  pages: {max_pages}[/white]\n"
    )
    started = time.time()

    pages = crawl(domain, timeout, max_pages)
    exposures = parse_exposures(pages)

    table = summary = None
    if exposures:
        table = Table(
            title=f"Autocomplete Vulnerabilities – {domain}",
            header_style="bold magenta",
            box=box.MINIMAL,
        )
        columns = (
            ("Page URL", "cyan"),
            ("HTTP Status", "green"),
            ("Form Action", "yellow"),
            ("Field Name", "white"),
            ("Autocomplete", "blue"),
            ("Type", "magenta"),
            ("Risk", "red"),
        )
        for header, colour in columns:
            table.add_column(
                header,
                style=colour,
                overflow="fold",
                justify="right" if header == "HTTP Status" else None,
            )
        for row in exposures:
            table.add_row(*(str(cell) for cell in row))
        console.print(table)

        flagged = len([row for row in exposures if row[-1] == "⚠️"])
        summary = {
            "Pages scanned": len(pages),
            "Fields checked": len(exposures),
            "Exposures": flagged,
            "Elapsed": f"{time.time() - started:.2f}s",
        }
        console.print(Panel(
            "\n".join(f"[bold]{key}:[/bold] {value}" for key, value in summary.items()),
            title="Summary",
            style="bold white",
        ))
    else:
        console.print("[green][*] No autocomplete exposures found.[/green]")

    console.print("[white][*] Autocomplete check completed[/white]\n")

    if not EXPORT_SETTINGS.get("enable_txt_export"):
        return

    out_dir = os.path.join(RESULTS_DIR, domain)
    ensure_directory_exists(out_dir)
    # Re-render the report on a recording console to capture it as plain text.
    recorder = Console(record=True, width=console.width)
    if table is None:
        recorder.print("[green]No exposures found[/green]")
    else:
        recorder.print(table)
        recorder.print(Panel(
            "\n".join(f"{key}: {value}" for key, value in summary.items()),
            title="Summary",
            style="bold white",
        ))
    write_to_file(os.path.join(out_dir, "autocomplete_check.txt"), recorder.export_text())

if __name__ == "__main__":
    # Usage: <script> TARGET [THREADS] [OPTIONS_JSON]
    if len(sys.argv) < 2:
        console.print("[red]✖ No target provided.[/red]")
        sys.exit(1)
    tgt = sys.argv[1]
    # isdecimal(), unlike isdigit(), only accepts characters int() can parse
    # (isdigit() is True for e.g. "²", which int() rejects).
    thr = int(sys.argv[2]) if len(sys.argv) > 2 and sys.argv[2].isdecimal() else 1
    opts = {}
    if len(sys.argv) > 3:
        try:
            opts = json.loads(sys.argv[3])
        except json.JSONDecodeError as err:
            console.print(f"[red]✖ Invalid options JSON: {err}[/red]")
            sys.exit(1)
        # run() reads options via opts.get(), so anything but an object is unusable.
        if not isinstance(opts, dict):
            console.print("[red]✖ Options must be a JSON object.[/red]")
            sys.exit(1)
    run(tgt, thr, opts)
